# Single node training a GPT style model with Nvidia NeMo
#
# This script downloads data from a read-only bucket at gs://sky-wiki-data.
# If you want to preprocess the data yourself, see nemo_gpt_preprocessing.yaml.
#
# The specific model used here should fit on a single GPU with 16 GB of memory.
#
# After the script completes, the model checkpoints will be saved in
# ~/sky_workdir/nemo_experiments/megatron_gpt/checkpoints on the head node.
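#
# To copy the checkpoints to your local machine after training, one option is
# rsync over the SSH alias that SkyPilot sets up for the cluster, e.g.:
#   rsync -avz nemo_gpt:~/sky_workdir/nemo_experiments/megatron_gpt/checkpoints ./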
#
# Usage:
#   sky launch -s -c nemo_gpt nemo_gpt_singlenode.yaml
#
#   # Or try on spot A100 GPUs:
#   sky launch -c nemo_gpt nemo_gpt_singlenode.yaml --use-spot --gpus A100:1
#
#   # The setup will take some time (~1 hr); feel free to Ctrl-C once the setup script starts.
#   # You can reconnect to the log stream with `sky logs nemo_gpt`.
#
#   # Terminate cluster after you're done
#   sky down nemo_gpt

resources:
  cpus: 6+
  accelerators: A100:1

num_nodes: 1

envs:
  DATASET_ROOT: $HOME/wiki/
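  # DATASET_ROOT can be overridden at launch time, e.g.:
  #   sky launch -c nemo_gpt nemo_gpt_singlenode.yaml --env DATASET_ROOT=/data/wiki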

setup: |
  # ============== Dependency Setup ==============
  conda activate nemo
  if [ $? -eq 0 ]; then
      echo "Nemo conda env exists"
  else
      echo "Setup start"
  
      conda create -y --name nemo python==3.10.12
      conda activate nemo
  
      # Install PyTorch
      pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
      
      # Install nemo
      git clone https://github.com/NVIDIA/NeMo.git
      cd NeMo
      git checkout b4ad7eaa7873d632391d6985aa6b359f39c20bab
      pip install Cython
      pip install .[all]
      cd ..
  
      # Install megatron-core
      # We install in editable mode because setup.py does not install all 
      # required modules if we install in non-editable mode.
      git clone https://github.com/NVIDIA/Megatron-LM
      cd Megatron-LM
      git checkout dc21350806361564b8ce61d4a8d247cb195cc5f0
      pip install -e .  
      cd ..
      
      # Install ninja for faster compilation
      pip install ninja packaging
  
      # Install transformer engine and flash-attn (Takes ~1hr to compile)
      MAX_JOBS=4 pip install flash-attn==2.0.4 --no-build-isolation # Version upper capped by TransformerEngine
      MAX_JOBS=4 pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
      
      pip install pytorch-extension
  
      # Install Apex
      git clone https://github.com/NVIDIA/apex.git
      cd apex
      git checkout 52e18c894223800cb611682dce27d88050edf1de
      pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
      cd ..
        
      # Install gsutil if it doesn't exist
      if ! command -v gsutil &> /dev/null
      then
          pip install gsutil
      else
          echo "gsutil exists"
      fi
  fi
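
  # (Optional) Sanity-check that the key packages import cleanly before training, e.g.:
  #   python -c "import nemo, apex, transformer_engine; print('env OK')"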

run: |
  conda activate nemo
  # ============= Data Download =============
  # We download pre-processed data from a read-only bucket at gs://sky-wiki-data
  # For more on how to pre-process the data yourself, see nemo_gpt_preprocessing.yaml
  
  if [ -f ${DATASET_ROOT}/hfbpe_gpt_training_data_text_document.bin ]; then
      echo "Data already downloaded"
  else
      echo "Head node downloading data to shared bucket."
      mkdir -p $DATASET_ROOT
      gsutil -m cp gs://sky-wiki-data/{gpt2-merges.txt,gpt2-vocab.json,hfbpe_gpt_training_data_text_document.bin,hfbpe_gpt_training_data_text_document.idx} ${DATASET_ROOT}
  fi 
  
  # ============= Training =============    
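  # Note on batch sizes: with 1 GPU and tensor/pipeline parallelism of 1, the
  # data-parallel size is 1, so global_batch_size=192 with micro_batch_size=6
  # implies 192 / (6 * 1) = 32 gradient-accumulation steps, which NeMo derives
  # internally from these values (hence trainer.accumulate_grad_batches=1).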
  python NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py  \
    --config-path=conf \
    --config-name=megatron_gpt_config \
    trainer.devices=${SKYPILOT_NUM_GPUS_PER_NODE} \
    trainer.num_nodes=1 \
    trainer.max_epochs=null \
    trainer.max_steps=300000 \
    trainer.val_check_interval=300 \
    trainer.log_every_n_steps=50 \
    trainer.limit_val_batches=50 \
    trainer.limit_test_batches=50 \
    trainer.accumulate_grad_batches=1 \
    trainer.precision=16 \
    model.micro_batch_size=6 \
    model.global_batch_size=192 \
    model.tensor_model_parallel_size=1 \
    model.pipeline_model_parallel_size=1 \
    model.max_position_embeddings=1024 \
    model.encoder_seq_length=1024 \
    model.hidden_size=768 \
    model.ffn_hidden_size=3072 \
    model.num_layers=12 \
    model.num_attention_heads=12 \
    model.init_method_std=0.021 \
    model.hidden_dropout=0.1 \
    model.layernorm_epsilon=1e-5 \
    model.tokenizer.vocab_file=${DATASET_ROOT}/gpt2-vocab.json \
    model.tokenizer.merge_file=${DATASET_ROOT}/gpt2-merges.txt \
    model.data.data_prefix=[1.0,${DATASET_ROOT}/hfbpe_gpt_training_data_text_document] \
    model.data.num_workers=2 \
    model.data.seq_length=1024 \
    model.data.splits_string=\'980,10,10\' \
    model.optim.name=fused_adam \
    model.optim.lr=6e-4 \
    model.optim.betas=[0.9,0.95] \
    model.optim.weight_decay=0.1 \
    model.optim.sched.name=CosineAnnealing \
    model.optim.sched.warmup_steps=750 \
    model.optim.sched.constant_steps=80000 \
    model.optim.sched.min_lr=6e-5 \
    exp_manager.resume_if_exists=True \
    exp_manager.resume_ignore_no_checkpoint=True \
    exp_manager.create_checkpoint_callback=True \
    exp_manager.checkpoint_callback_params.monitor=val_loss \
    exp_manager.checkpoint_callback_params.save_top_k=3 \
    exp_manager.checkpoint_callback_params.mode=min \
    exp_manager.checkpoint_callback_params.always_save_nemo=False
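
  # Because exp_manager.resume_if_exists=True, re-running this task on the same
  # cluster (e.g. `sky exec nemo_gpt nemo_gpt_singlenode.yaml`) resumes training
  # from the latest checkpoint instead of starting over.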
